#!/usr/bin/env python3

"""
Downloads latest Wycheproof test vectors and generates C test code.
"""

from contextlib import redirect_stdout
from dataclasses import dataclass
import datetime
import requests
import os
from pathlib import Path
import sys


@dataclass()
class EddsaVerify:
    """One EdDSA signature-verification vector collected for C code generation."""

    tcId: int  # test-case number taken from the source vector file
    comment: str  # free-text description, emitted as the C `comment` field
    msg: bytes  # message the signature is checked against
    sig: bytes  # signature bytes (generators only emit 64-byte ones)
    pub: bytes  # public key bytes
    ok: bool  # True when the signature is expected to verify


def _gen_ed25519():
    """Print C source for the Wycheproof Ed25519 verification vectors.

    Downloads ``eddsa_test.json`` from the Wycheproof repository and collects
    every test in the ``EddsaVerify`` groups. Other group types are reported
    on stderr and skipped. It then prints to stdout a C translation unit that
    defines ``ed25519_verify_wycheproofs[]``, terminated by a zeroed entry.

    A test is marked ``ok`` only when its Wycheproof result is ``"valid"``.
    Any other result is emitted as ``ok = 0``. Tests whose signature is not
    exactly 64 bytes are left out because they do not fit the fixed-size
    ``sig[64]`` field.

    Raises:
        requests.HTTPError: the download returned an HTTP error status.
        ValueError: the downloaded file is not the expected EdDSA vector file.
    """

    def c_hex(data):
        # Every byte becomes a \xNN escape. Each escape is followed by another
        # backslash or by the closing quote, so C never merges adjacent hex
        # digits into a single escape.
        return "".join(r"\x%02x" % b for b in data)

    def c_str(text):
        # Escape the characters that would end or corrupt a C string literal.
        return text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")

    req = requests.get(
        "https://raw.githubusercontent.com/google/wycheproof/master/testvectors/eddsa_test.json",
        timeout=60,  # never hang indefinitely on a stalled connection
    )
    # Explicit checks instead of assert, which is stripped under `python -O`.
    req.raise_for_status()
    file = req.json()
    if file["algorithm"] != "EDDSA" or file["schema"] != "eddsa_verify_schema.json":
        raise ValueError(
            f"unexpected vector file: algorithm={file['algorithm']!r} "
            f"schema={file['schema']!r}"
        )

    verify_tests = []
    for group in file["testGroups"]:
        if group["type"] != "EddsaVerify":
            print(f"Skipping {group['type']} test", file=sys.stderr)
            continue
        # All tests in a group share the group's public key.
        pubkey = bytes.fromhex(group["key"]["pk"])
        for test in group["tests"]:
            verify_tests.append(
                EddsaVerify(
                    tcId=test["tcId"],
                    comment=test["comment"],
                    msg=bytes.fromhex(test["msg"]),
                    sig=bytes.fromhex(test["sig"]),
                    pub=pubkey,
                    ok=test["result"] == "valid",
                )
            )

    print("/* Code generated by gen_wycheproofs.py. DO NOT EDIT. */")
    print(
        f"/* Generated at {datetime.datetime.now(datetime.timezone.utc).isoformat()} */"
    )
    print(
        """
#include "../fd_ballet_base.h"

struct fd_ed25519_verify_wycheproof {
  char const *  comment;
  uchar const * msg;
  ulong         msg_sz;
  uchar         pub[32];
  uchar         sig[64];
  uint          tc_id;
  int           ok;
};

typedef struct fd_ed25519_verify_wycheproof fd_ed25519_verify_wycheproof_t;

static fd_ed25519_verify_wycheproof_t const ed25519_verify_wycheproofs[] = {"""
    )

    for test in verify_tests:
        # Signatures of any other length cannot be stored in `sig[64]`.
        if len(test.sig) != 64:
            continue
        print(f"  {{ .tc_id   = {test.tcId},")
        print(f'    .comment = "{c_str(test.comment)}",')
        print(f'    .msg     = (uchar const *)"{c_hex(test.msg)}",')
        print(f"    .msg_sz  = {len(test.msg)}UL,")
        print(f'    .sig     = "{c_hex(test.sig)}",')
        print(f'    .pub     = "{c_hex(test.pub)}",')
        print(f"    .ok      = {1 if test.ok else 0} }},")

    # Zeroed sentinel entry terminates the table.
    print(r"  {0}")
    print(r"};")


@dataclass()
class XDHVerify:
    """One X25519 key-agreement vector collected for C code generation."""

    tcId: int  # test-case number taken from the Wycheproof file
    comment: str  # free-text description, emitted as the C `comment` field
    shared: bytes  # expected shared secret
    prv: bytes  # private key bytes
    pub: bytes  # peer public key bytes
    ok: bool  # True unless the vector carries the ZeroSharedSecret flag

def _gen_x25519():
    """Print C source for the Wycheproof X25519 key-agreement vectors.

    Downloads ``x25519_test.json`` from the Wycheproof repository and collects
    every test in the ``XdhComp`` groups. Other group types are reported on
    stderr and skipped. It then prints to stdout a C translation unit that
    defines ``x25519_verify_wycheproofs[]``, terminated by a zeroed entry.

    Vectors flagged ``ZeroSharedSecret`` are emitted with ``ok = 0``. All
    others are emitted with ``ok = 1``; the Wycheproof ``result`` field is not
    consulted. Vectors whose shared secret, private key or public key is not
    exactly 32 bytes are reported on stderr and skipped.

    Raises:
        requests.HTTPError: the download returned an HTTP error status.
        ValueError: the downloaded file is not the expected XDH vector file.
    """

    def c_hex(data):
        # Every byte becomes a \xNN escape. Each escape is followed by another
        # backslash or by the closing quote, so C never merges adjacent hex
        # digits into a single escape.
        return "".join(r"\x%02x" % b for b in data)

    def c_str(text):
        # Escape the characters that would end or corrupt a C string literal.
        return text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")

    req = requests.get(
        "https://raw.githubusercontent.com/google/wycheproof/master/testvectors/x25519_test.json",
        timeout=60,  # never hang indefinitely on a stalled connection
    )
    # Explicit checks instead of assert, which is stripped under `python -O`.
    req.raise_for_status()
    file = req.json()
    if file["algorithm"] != "XDH" or file["schema"] != "xdh_comp_schema.json":
        raise ValueError(
            f"unexpected vector file: algorithm={file['algorithm']!r} "
            f"schema={file['schema']!r}"
        )

    verify_tests = []
    for group in file["testGroups"]:
        if group["type"] != "XdhComp":
            print(f"Skipping {group['type']} test", file=sys.stderr)
            continue
        for test in group["tests"]:
            verify_tests.append(
                XDHVerify(
                    tcId=test["tcId"],
                    comment=test["comment"],
                    shared=bytes.fromhex(test["shared"]),
                    prv=bytes.fromhex(test["private"]),
                    pub=bytes.fromhex(test["public"]),
                    ok="ZeroSharedSecret" not in (test.get("flags") or ()),
                )
            )

    print("/* Code generated by gen_wycheproofs.py. DO NOT EDIT. */")
    print(
        f"/* Generated at {datetime.datetime.now(datetime.timezone.utc).isoformat()} */"
    )
    print(
        """
#include "../fd_ballet_base.h"

struct fd_x25519_verify_wycheproof {
  char const *  comment;
  uchar         shared[32];
  uchar         prv[32];
  uchar         pub[32];
  uint          tc_id;
  int           ok;
};

typedef struct fd_x25519_verify_wycheproof fd_x25519_verify_wycheproof_t;

static fd_x25519_verify_wycheproof_t const x25519_verify_wycheproofs[] = {"""
    )

    for test in verify_tests:
        # Each value is stored in a fixed uchar[32]. A longer literal would not
        # compile and a shorter one would be silently zero-padded, so skip any
        # vector that does not fit (mirrors the Ed25519 sig-length guard).
        if any(len(field) != 32 for field in (test.shared, test.prv, test.pub)):
            print(
                f"Skipping x25519 tcId {test.tcId}: field is not 32 bytes",
                file=sys.stderr,
            )
            continue
        print(f"  {{ .tc_id   = {test.tcId},")
        print(f'    .comment = "{c_str(test.comment)}",')
        print(f'    .shared  = "{c_hex(test.shared)}",')
        print(f'    .prv     = "{c_hex(test.prv)}",')
        print(f'    .pub     = "{c_hex(test.pub)}",')
        print(f"    .ok      = {1 if test.ok else 0} }},")

    # Zeroed sentinel entry terminates the table.
    print(r"  {0}")
    print(r"};")


def _gen_cctv_ed25519():
    """Print C source for the C2SP CCTV Ed25519 edge-case vectors.

    Downloads ``ed25519vectors.json`` from the C2SP CCTV repository and prints
    to stdout a C translation unit that defines ``ed25519_verify_cctvs[]``,
    terminated by a zeroed entry. Each vector's ``msg`` string serves two
    purposes: UTF-8 encoded, it is the signed message, and it is also used as
    the entry's comment.

    A vector is expected to verify (``ok = 1``) when it has no flags, or when
    its flags are exactly ``{low_order_component_A}`` or
    ``{low_order_component_A, low_order_component_R}``. Every other flagged
    vector is expected to be rejected. NOTE(review): this classification is
    described elsewhere as matching dalek's ``verify_strict``; re-check it if
    the upstream vectors change.

    Raises:
        requests.HTTPError: the download returned an HTTP error status.
        ValueError: a vector has ``non_canonical_R`` without ``low_order_R``,
            a combination this generator refuses to classify.
    """

    def c_hex(data):
        # Every byte becomes a \xNN escape. Each escape is followed by another
        # backslash or by the closing quote, so C never merges adjacent hex
        # digits into a single escape.
        return "".join(r"\x%02x" % b for b in data)

    def c_str(text):
        # Escape the characters that would end or corrupt a C string literal.
        return text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")

    # Flag sets whose vectors are still expected to verify (the empty set
    # covers unflagged vectors).
    accepted_flag_sets = (
        frozenset(),
        frozenset({"low_order_component_A"}),
        frozenset({"low_order_component_A", "low_order_component_R"}),
    )

    req = requests.get(
        "https://raw.githubusercontent.com/C2SP/CCTV/main/ed25519/ed25519vectors.json",
        timeout=60,  # never hang indefinitely on a stalled connection
    )
    # Explicit check instead of assert, which is stripped under `python -O`.
    req.raise_for_status()
    file = req.json()

    verify_tests = []
    for test in file:
        # Any falsy "flags" value (e.g. null or []) means the vector is unflagged.
        flags = frozenset(test["flags"] or ())
        if "non_canonical_R" in flags and "low_order_R" not in flags:
            raise ValueError(
                f"CCTV vector {test['number']}: non_canonical_R without low_order_R"
            )
        verify_tests.append(
            EddsaVerify(
                tcId=test["number"],
                comment=test["msg"],
                msg=test["msg"].encode("utf-8"),
                sig=bytes.fromhex(test["sig"]),
                pub=bytes.fromhex(test["key"]),
                ok=flags in accepted_flag_sets,
            )
        )

    print("/* Code generated by gen_wycheproofs.py. DO NOT EDIT. */")
    print(
        f"/* Generated at {datetime.datetime.now(datetime.timezone.utc).isoformat()} */"
    )
    print(
        """
#include "../fd_ballet_base.h"

struct fd_ed25519_verify_cctv {
  char const *  comment;
  uchar const * msg;
  ulong         msg_sz;
  uchar         pub[32];
  uchar         sig[64];
  uint          tc_id;
  int           ok;
};

typedef struct fd_ed25519_verify_cctv fd_ed25519_verify_cctv_t;

static fd_ed25519_verify_cctv_t const ed25519_verify_cctvs[] = {"""
    )

    for test in verify_tests:
        # Signatures of any other length cannot be stored in `sig[64]`.
        if len(test.sig) != 64:
            continue
        print(f"  {{ .tc_id   = {test.tcId},")
        print(f'    .comment = "{c_str(test.comment)}",')
        print(f'    .msg     = (uchar const *)"{c_hex(test.msg)}",')
        print(f"    .msg_sz  = {len(test.msg)}UL,")
        print(f'    .sig     = "{c_hex(test.sig)}",')
        print(f'    .pub     = "{c_hex(test.pub)}",')
        print(f"    .ok      = {1 if test.ok else 0} }},")

    # Zeroed sentinel entry terminates the table.
    print(r"  {0}")
    print(r"};")

def main():
    """Regenerate the Wycheproof and CCTV C test files.

    Each generator's output is captured in memory. The target file is written
    only after that generator has finished successfully, so a download or
    parse failure leaves the existing file untouched instead of truncated.

    Output paths are relative to the current working directory. The
    ``__main__`` entry point changes into the directory two levels above this
    script's directory before calling this.
    """
    import io

    outputs = (
        ("src/ballet/ed25519/test_ed25519_wycheproof.c", _gen_ed25519),
        ("src/ballet/ed25519/test_x25519_wycheproof.c", _gen_x25519),
        ("src/ballet/ed25519/test_ed25519_cctv.c", _gen_cctv_ed25519),
    )
    for path, generate in outputs:
        buf = io.StringIO()
        with redirect_stdout(buf):
            generate()
        # Only reached on success; write with an explicit, locale-independent encoding.
        Path(path).write_text(buf.getvalue(), encoding="utf-8")

if __name__ == "__main__":
    # Output paths in main() are relative to the directory two levels above
    # the one containing this script. Make __file__ absolute first: on
    # Python < 3.9 it can be relative (e.g. when run from the script's own
    # directory), which makes parents[2] wrong or raises IndexError.
    root = Path(os.path.abspath(__file__)).parents[2]
    os.chdir(root)
    main()
